import pandas as pd
import statsmodels.formula.api as smf
import numpy as np

# Read the prepared panel (this version carries dep_ratio and hh_type_strat).
df = pd.read_csv('cleaned_with_male_head.csv')

# Cast label/grouping columns to pandas categoricals in one pass.
df = df.astype({col: 'category' for col in ('CHILD', 'year', 'H_TYPE', 'hh_type_strat')})

# Demean (include dep_ratio)
def demean(df, vars_to_demean, group_col='code_id_hh'):
    """Return a copy of ``df`` with selected columns demeaned within groups.

    For each listed column, the per-group mean (groups defined by
    ``group_col``) is subtracted from every row of that group. This is the
    within transformation. With the default ``group_col`` the groups are
    presumably households; confirm against the data dictionary. The input
    DataFrame is not modified.

    Args:
        df: input DataFrame containing ``group_col`` and the columns to demean.
        vars_to_demean: column names to transform, processed in order. Names
            not present in ``df`` are skipped with a ``UserWarning``, so a typo
            cannot silently leave a regressor untransformed.
        group_col: column whose values define the groups.

    Returns:
        A new DataFrame. Listed columns are replaced by their demeaned values;
        all other columns are copied unchanged.
    """
    import warnings

    df_demean = df.copy()
    for var in vars_to_demean:
        if var not in df_demean.columns:
            # Previously skipped silently; warn so misspelled names surface.
            warnings.warn(
                f"demean: column {var!r} not found in DataFrame; left untransformed",
                stacklevel=2,
            )
            continue
        group_mean = df_demean.groupby(group_col)[var].transform('mean')
        df_demean[var] = df_demean[var] - group_mean
    return df_demean

# Outcome (fert) plus the non-categorical regressors used in the model below,
# demeaned within groups of code_id_hh (dep_ratio included).
vars_to_demean = [
    'fert', 'htype', 'dep_ratio', 'mar', 'emp',
    'pri', 'sec', 'high', 'inc', 'fam',
]
df_demean = demean(df, vars_to_demean, group_col='code_id_hh')

# CI function
def get_param_info(model, param, z=1.96):
    """Format a coefficient and its symmetric confidence interval as text.

    Args:
        model: fitted results object. Only ``model.params`` and ``model.bse``
            are read, each looked up by name with ``.get`` (e.g. pandas Series).
        param: term name to look up, e.g. ``'htype'`` or ``'C(year)[T.2010.0]'``.
        z: critical value used for the half-width ``z * se``. The default 1.96
            gives a normal-approximation 95% interval; pass e.g. 1.645 for 90%.

    Returns:
        ``"coef (low to high)"`` with each number rounded to 4 decimals, or
        ``"N/A"`` when the term is missing from ``model.params`` (or its
        coefficient is NaN).
    """
    coeff = model.params.get(param, np.nan)
    if np.isnan(coeff):
        return "N/A"
    se = model.bse.get(param, np.nan)
    margin = z * se
    ci_low = coeff - margin
    ci_high = coeff + margin
    return f"{round(coeff, 4)} ({round(ci_low, 4)} to {round(ci_high, 4)})"

# Function to build full table from model
def build_full_table(model, model_name):
    """Assemble the labelled regression table for one fitted model.

    Each body row pairs a display label with the coefficient and CI string
    produced by ``get_param_info``. Terms absent from the model show "N/A".
    Two footer rows follow: the observation count and R-squared.

    Args:
        model: fitted results object exposing ``params``, ``bse``, ``nobs``
            and ``rsquared``.
        model_name: header for the results column.

    Returns:
        pandas DataFrame with columns ``'Variable'`` and ``model_name``.
    """
    # (model term name, display label). Keeping each pair together prevents
    # the term list and the label list from drifting out of alignment.
    row_spec = [
        ('htype', 'Living with grandparent(s)^a'),
        ('dep_ratio', 'Dependency ratio'),
        ('htype:dep_ratio', 'Interaction: htype * dep_ratio'),
        ('C(CHILD)[T.1]', 'Sex of children^b: At least 1 son'),
        ('C(CHILD)[T.2]', '- At least 1 daughter'),
        ('C(CHILD)[T.3]', '- At least 1 son and 1 daughter'),
        ('mar', 'Proportion of married women'),
        ('emp', 'Proportion of employed women'),
        ('pri', 'Proportion of women with primary education'),
        ('sec', ' - Secondary education'),
        ('high', ' - High school education'),
        ('inc', 'Household income'),
        ('fam', 'Household size'),
        ('Intercept', 'Constant'),
        ('C(year)[T.2010.0]', 'Year 2010'),
        ('C(year)[T.2012.0]', 'Year 2012'),
        ('C(year)[T.2014.0]', 'Year 2014'),
        ('C(year)[T.2016.0]', 'Year 2016'),
    ]
    labels = [label for _, label in row_spec]
    cells = [get_param_info(model, term) for term, _ in row_spec]

    # Footer rows go into the same object-dtype column as the body. A separate
    # two-row DataFrame of [int, float] would be inferred as float64, and the
    # observation count would be printed/saved as e.g. "5000.0".
    labels += ['Observations', 'R-squared']
    cells += [int(model.nobs), round(model.rsquared, 3)]

    return pd.DataFrame({
        'Variable': labels,
        model_name: pd.Series(cells, dtype=object),
    })

# Interaction model (binary htype * dep_ratio) with SEs clustered by code_id_hh.
# NOTE(review): htype and dep_ratio are demeaned within code_id_hh, but the
# interaction term is the product of the two demeaned variables (not the
# demeaned product). The C(CHILD) / C(year) dummies are not demeaned at all.
# Estimates may therefore differ from a standard household fixed-effects
# regression with these terms. Confirm this is the intended specification.
formula_int = 'fert ~ htype * dep_ratio + C(CHILD) + mar + emp + pri + sec + high + inc + fam + C(year)'

# Drop incomplete rows before fitting. The formula API drops rows with missing
# values itself, but the cluster `groups` Series is passed through as-is, so
# any NaN in a model variable would leave it misaligned with the rows used.
model_cols = ['fert', 'htype', 'dep_ratio', 'CHILD', 'mar', 'emp', 'pri',
              'sec', 'high', 'inc', 'fam', 'year', 'code_id_hh']
model_data = df_demean.dropna(subset=model_cols)
model_int = smf.ols(formula_int, data=model_data).fit(
    cov_type='cluster', cov_kwds={'groups': model_data['code_id_hh']}
)

# Full Table 5c
table5c = build_full_table(model_int, 'No. of children under 3')
print("\nFull Table 5c: Binary Model with Dependency Ratio Interaction\n", table5c)
table5c.to_csv('revise_table_5c_dep.csv', index=False)